import random
from policies.utils.transition import Transition

class ReplayBuffer:
    """Store transitions generated by interactions between the policy and environment.

    Transitions are kept in insertion order in the ``memory`` list. Batches
    returned by :meth:`sample_batch` are contiguous runs of that list.
    """

    def __init__(self):
        # Transitions in the order they were added.
        self.memory = []

    def add(self, transition: "Transition"):
        """Append ``transition`` to the end of the buffer."""
        self.memory.append(transition)

    def sample_batch(self, batch_size: int = 1) -> list:
        """Return a contiguous batch of stored transitions.

        Args:
            batch_size: Number of consecutive transitions to return.

        Returns:
            A new list. When the buffer holds ``batch_size`` transitions or
            fewer, it contains every stored transition; otherwise it contains
            ``batch_size`` consecutive transitions starting at a uniformly
            random position.

        Raises:
            ValueError: If ``batch_size`` is negative.
        """
        if batch_size < 0:
            # A negative size would reach the slice below as a negative end
            # index, which wraps around and yields an arbitrary chunk.
            raise ValueError(f"batch_size must be non-negative, got {batch_size}")

        trajectory_length = len(self.memory)

        # If the trajectory is no longer than the batch size, return the whole
        # trajectory. Copy it so callers never hold (and cannot mutate) the
        # internal list, matching the slice returned on the sampling path.
        if trajectory_length <= batch_size:
            return list(self.memory)

        # Otherwise, sample a starting point and return a batch of length
        # batch_size. random.randint is inclusive on both ends, so the last
        # valid start is trajectory_length - batch_size.
        start_point = random.randint(0, trajectory_length - batch_size)
        return self.memory[start_point:start_point + batch_size]

    def update(self):
        """Do nothing; currently a no-op that returns ``None``."""
        return

    def clear(self):
        """Drop all stored transitions by rebinding ``memory`` to a new empty list."""
        self.memory = []

    def len(self):
        """Return the number of stored transitions (same as ``len(buffer)``)."""
        return self.__len__()

    def __len__(self):
        """Return the number of stored transitions."""
        return len(self.memory)

    def __getitem__(self, idx):
        """Return ``self.memory[idx]``; ``idx`` is passed straight to the list."""
        return self.memory[idx]